import argparse
import configparser
import os
import sys
import time
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
sys.path.append(root_dir)
import torch
from tatk.nlg.sclstm.multiwoz.loader.dataset_woz import DatasetWoz, SimpleDatasetWoz
from tatk.nlg.sclstm.model.lm_deep import LMDeep

# True when torch reports an available CUDA device; the code below uses it to decide
# whether the dataset, the model and the input tensors are placed on the GPU.
USE_CUDA = torch.cuda.is_available()


def score(feat, gen, template):
	'''
	Count the slot errors of one generated sentence against its input features.

	Args:
		feat: list of feature strings, e.g.
			['d-a-s-v:Booking-Book-Day-1', 'd-a-s-v:Booking-Book-Name-1', 'd-a-s-v:Booking-Book-Name-2']
		gen: generated sentence; it is split on whitespace and searched for slot tokens
			built as '@' + first three letters of the domain + '-' + act + '-' + slot,
			all lower-cased, e.g. '@boo-book-name'
		template: path of a text file; every line containing 'd-a-s-v:' is a template entry
	Returns:
		(total, redunt, miss)
			total:  number of items in feat equal to a template entry followed by an index 0..19
			redunt: slot tokens that appear in gen more often than requested by feat
			miss:   slot tokens requested by feat that are absent from gen

	NOTE(review): interact() in this file treats tokens starting with 'slot-' as
	placeholders, while this function looks for '@...' tokens - confirm which form the
	model vocabulary really uses.
	'''
	from collections import Counter  # local import keeps this block self-contained

	max_index = 20  # feature indices 0..19 are recognised, as before
	skip_markers = ('-none', '-?', '-yes', '-no')  # entries that carry no slot placeholder

	# Ordered, de-duplicated template entries with the trailing '-<value>' part removed,
	# e.g. 'd-a-s-v:Booking-Book-Day'. A dict gives O(1) de-duplication and keeps file order.
	das = {}
	with open(template) as f:
		for line in f:
			if 'd-a-s-v:' not in line:
				continue
			if any(marker in line for marker in skip_markers):
				continue
			das['-'.join(line.strip().split('-')[:-1])] = None

	# Tokenise and tally once instead of once per template entry.
	feat_counts = Counter(feat)
	gen_counts = Counter(gen.split())

	total, redunt, miss = 0, 0, 0
	for entry in das:
		feat_count = sum(feat_counts[entry + '-' + str(i)] for i in range(max_index))

		parts = entry.replace('d-a-s-v:', '').lower().split('-')
		slot_tok = '@' + parts[0][:3] + '-' + parts[1] + '-' + parts[2]

		diff_count = gen_counts[slot_tok] - feat_count
		if diff_count > 0:
			redunt += diff_count
		else:
			miss += -diff_count
		total += feat_count
	return total, redunt, miss


def get_slot_error(dataset, gens, refs, sv_indexes):
	'''
	Compute the slot error of every generated sentence in a batch.

	Args:
		dataset:    object providing cardinality, dfs and template (see below)
		gens:       (batch_size, beam_size) generated sentences
		refs:       (batch_size,) reference sentences; only their count is checked here
		sv_indexes: (batch_size,) per-sample lists of indices; each index is shifted by
		            dataset.dfs[2] and looked up in dataset.cardinality
		            (presumably the offset of the slot-value section - confirm against the dataset class)
	Returns:
		count: accumulative slot error of a batch, keys 'total' / 'redunt' / 'miss'
		countPerGen: slot error for each sample, indexed [batch_idx][beam_idx]
	'''
	batch_size = len(gens)
	assert len(refs) == batch_size and len(sv_indexes) == batch_size

	count = {'total': 0.0, 'redunt': 0.0, 'miss': 0.0}
	countPerGen = [ [] for _ in range(batch_size) ]
	if batch_size == 0:
		# nothing to score; avoids indexing gens[0] on an empty batch
		return count, countPerGen

	# the beam width is taken from the first sample, as before
	beam_size = len(gens[0])
	sv_offset = dataset.dfs[2]
	for batch_idx in range(batch_size):
		# the requested features depend on the sample only, so build them once per sample
		# rather than once per beam
		felements = [dataset.cardinality[x + sv_offset] for x in sv_indexes[batch_idx]]

		for beam_idx in range(beam_size):
			# get slot error per sample(beam)
			total, redunt, miss = score(felements, gens[batch_idx][beam_idx], dataset.template)

			c = {'total': total, 'redunt': redunt, 'miss': miss}
			for key, value in c.items():
				count[key] += value
			countPerGen[batch_idx].append(c)

	return count, countPerGen
		

def evaluate(config, dataset, model, data_type, beam_search, beam_size, batch_size):
	'''
	Run the model over every batch of one data split and report loss and slot error.

	Args:
		config:      not used by this function; kept for interface compatibility
		dataset:     provides n_batch[data_type] and next_batch(data_type=...)
		model:       callable model exposing eval() and get_loss()
		data_type:   'valid' computes the loss with ground-truth input and then generates
		             with beam_search=False, beam_size=1; any other value is handled as
		             test: generation only, and every result is printed
		beam_search: forwarded to the model in test mode
		beam_size:   forwarded to the model in test mode; also caps the beams printed per sample
		batch_size:  caps the number of samples printed per batch in test mode
	Returns:
		loss averaged over the batches (stays 0 in test mode, where no loss is computed)
	'''
	t = time.time()
	model.eval()

	n_batch = dataset.n_batch[data_type]
	total_loss = 0
	countAll = {'total': 0.0, 'redunt': 0.0, 'miss': 0.0}
	# no parameter update happens in this function, so skip building the autograd graph
	with torch.no_grad():
		for i in range(n_batch):
			input_var, label_var, feats_var, lengths, refs, featStrs, sv_indexes = dataset.next_batch(data_type=data_type)

			if data_type == 'valid':
				# feed-forward with ground truth as input
				# (presumably caches the outputs that get_loss() scores - confirm in the model class)
				model(input_var, dataset, feats_var, gen=False, beam_search=False, beam_size=1)

				# update loss
				loss = model.get_loss(label_var, lengths)
				total_loss += loss.item()

				# run generation for calculating slot error
				decoded_words = model(input_var, dataset, feats_var, gen=True, beam_search=False, beam_size=1)
				countBatch, countPerGen = get_slot_error(dataset, decoded_words, refs, sv_indexes)

			else: # test
				print('decode batch {} out of {}'.format(i, n_batch), file=sys.stderr)
				decoded_words = model(input_var, dataset, feats_var, gen=True, beam_search=beam_search, beam_size=beam_size)
				countBatch, countPerGen = get_slot_error(dataset, decoded_words, refs, sv_indexes)

				# print generation results; bound both loops by what was actually produced so a
				# short batch or a narrower beam cannot raise IndexError
				for batch_idx in range(min(batch_size, len(decoded_words))):
					print('Feat: {}'.format(featStrs[batch_idx]))
					print('Target: {}'.format(refs[batch_idx]))
					beams = countPerGen[batch_idx]
					for beam_idx in range(min(beam_size, len(beams))):
						c = beams[beam_idx]
						s = decoded_words[batch_idx][beam_idx]
						print('Gen{} ({},{},{}): {}'.format(beam_idx, c['redunt'], c['miss'], c['total'], s))
					print('-----------------------------------------------------------')

			# accumulate slot error across batches
			for _type in countAll:
				countAll[_type] += countBatch[_type]

	# guard both averages against empty input instead of raising ZeroDivisionError
	if n_batch:
		total_loss /= n_batch
	if countAll['total']:
		se = (countAll['redunt'] + countAll['miss']) / countAll['total'] * 100
	else:
		se = 0.0

	# build each report line once and send it to both stdout and stderr
	summary = '{} Loss: {:.3f} | Slot error: {:.3f} | Time: {:.1f}'.format(data_type, total_loss, se, time.time()-t)
	counts = 'redunt: {}, miss: {}, total: {}'.format(countAll['redunt'], countAll['miss'], countAll['total'])
	print(summary)
	print(summary, file=sys.stderr)
	print(counts)
	print(counts, file=sys.stderr)
	return total_loss


def train_epoch(config, dataset, model):
	'''
	Run one pass over the training batches, updating the model after every batch.

	Args:
		config:  ConfigParser; MODEL/clip is read and handed to model.update()
		dataset: provides n_batch['train'] and next_batch()
		model:   callable model exposing train(), get_loss() and update()
	Prints the loss averaged over the batches and the elapsed time to stdout and stderr.
	'''
	started = time.time()
	model.train()

	n_batch = dataset.n_batch['train']
	running_loss = 0
	for _ in range(n_batch):
		input_var, label_var, feats_var, lengths, _refs, _feat_strs, _sv_indexes = dataset.next_batch()

		# forward pass, then accumulate this batch's loss
		model(input_var, dataset, feats_var)
		running_loss += model.get_loss(label_var, lengths).item()

		# parameter update; MODEL/clip is passed straight through to the model
		model.update(config.getfloat('MODEL', 'clip'))

	running_loss /= n_batch

	report = 'Train Loss: {:.3f} | Time: {:.1f}'
	for stream in (sys.stdout, sys.stderr):
		print(report.format(running_loss, time.time()-started), file=stream)


def read(config, args, mode):
	'''
	Build the dataset and the model, loading saved weights unless training from scratch.

	Args:
		config: ConfigParser; MODEL/hidden_size and MODEL/dropout are read here and the
		        whole object is handed to DatasetWoz
		args:   parsed command-line arguments (beam_search, percent, n_layer, lr,
		        beam_size and model_path are read here)
		mode:   'train' builds a fresh model and requires that args.model_path does not
		        exist yet; any other value ('test', 'adapt') loads the weights stored at
		        args.model_path, and eval() is called unless mode is 'adapt'
	Returns:
		(dataset, model)
	'''
	# get data
	print('Processing data...', file=sys.stderr)

	# TODO: remove this constraint
	# NOTE(review): only the message is emitted here; nothing in this function changes the batch size
	if mode == 'test' and args.beam_search:
		print('Set batch_size to 1 due to beam search')
		print('Set batch_size to 1 due to beam search', file=sys.stderr)

	dataset = DatasetWoz(config, percentage=args.percent, use_cuda=USE_CUDA)

	# get model hyper-parameters
	n_layer = args.n_layer
	hidden_size = config.getint('MODEL', 'hidden_size')
	dropout = config.getfloat('MODEL', 'dropout')
	lr = args.lr
	beam_size = args.beam_size

	# get feat size
	d_size = dataset.do_size + dataset.da_size + dataset.sv_size # len of 1-hot feat
	vocab_size = len(dataset.word2index)

	model_path = args.model_path
	model = LMDeep('sclstm', vocab_size, vocab_size, hidden_size, d_size, n_layer=n_layer, dropout=dropout, lr=lr, use_cuda=USE_CUDA)

	if mode == 'train':
		print('do not support re-train model, please make sure the model do not exist before training')
		assert not os.path.isfile(model_path), 'model file already exists: {}'.format(model_path)

	else: # mode is 'test' or 'adapt'
		# load model
		print(model_path, file=sys.stderr)
		assert os.path.isfile(model_path), 'model file not found: {}'.format(model_path)
		# map to CPU first so a checkpoint saved on a GPU also loads on a CPU-only host;
		# load_state_dict copies the values into the model's own parameters
		model.load_state_dict(torch.load(model_path, map_location='cpu'))
		if mode != 'adapt':
			model.eval()

	# Print model info
	print('\n***** MODEL INFO *****')
	print('MODEL PATH:', model_path)
	print('SIZE OF HIDDEN:', hidden_size)
	print('# of LAYER:', n_layer)
	print('SAMPLE/BEAM SIZE:', beam_size)
	print('*************************\n')

	# Move models to GPU
	if USE_CUDA:
		model.cuda()

	return dataset, model

	
def train(config, args):
	'''
	Train the model and keep the checkpoint with the lowest validation loss.

	Every epoch resets the dataset, runs train_epoch() and then evaluate() on the 'valid'
	split. The state dict is saved to args.model_path whenever the validation loss
	improves. Training ends after TRAINING/n_epochs epochs, or earlier once `patience`
	consecutive epochs bring no improvement.

	Args:
		config: ConfigParser; TRAINING/n_epochs is read here, the rest is passed on
		args:   parsed command-line arguments (mode, n_layer and model_path are read here)
	'''
	dataset, model = read(config, args, args.mode)
	n_layer = args.n_layer
	model_path = args.model_path

	patience = 6  # consecutive epochs without improvement tolerated before stopping
	# Must exist before the loop: if the very first validation loss is not an improvement
	# (for instance NaN), the else-branch below increments it straight away.
	earlyStop = 0
	best_loss = float('inf')

	# Start training
	epoch = 0
	print('Start training', file=sys.stderr)
	while epoch < config.getint('TRAINING', 'n_epochs'):
		dataset.reset()
		print('Epoch', epoch, '(n_layer {})'.format(n_layer))
		print('Epoch', epoch, '(n_layer {})'.format(n_layer), file=sys.stderr)
		train_epoch(config, dataset, model)

		loss = evaluate(config, dataset, model, 'valid', False, 1, dataset.batch_size)

		if loss < best_loss:
			earlyStop = 0
			# save model
			print('Best loss: {:.3f}, AND Save model!'.format(loss))
			print('Best loss: {:.3f}, AND Save model!'.format(loss), file=sys.stderr)
			torch.save(model.state_dict(), model_path)
			best_loss = loss
		else:
			earlyStop += 1

		if earlyStop >= patience: # no improvement on the dev set for `patience` epochs in a row
			return
		epoch += 1
		print('----------------------------------------')
		print('----------------------------------------', file=sys.stderr)


def test(config, args):
	'''
	Load a saved model and evaluate it on the 'test' split.

	Args:
		config: ConfigParser passed on to read() and evaluate()
		args:   parsed command-line arguments (mode, beam_search and beam_size are read here)
	'''
	dataset, model = read(config, args, 'test')

	banner = 'TEST ON: {}'.format(args.mode)
	print(banner)
	print(banner, file=sys.stderr)

	# Evaluate model; the returned loss is not used
	evaluate(config, dataset, model, 'test', args.beam_search, args.beam_size, dataset.batch_size)


def _number_repeated_slots(example):
    '''Insert, in place, a 1-based occurrence index (as a string) after each slot name.'''
    for pairs in example.values():
        seen = {}
        for pair in pairs:
            seen[pair[0]] = seen.get(pair[0], 0) + 1
            pair.insert(1, str(seen[pair[0]]))


def _lexicalize(sen, example):
    '''Replace every 'slot-domain-intent-type' placeholder in sen by its value from example.'''
    counter = {}  # 'Domain-Intent-Slot' -> index of the value substituted last
    for word in sen.split():
        if not word.startswith('slot-'):
            continue
        _, domain, intent, slot_type = word.split('-')
        da = domain.capitalize() + '-' + intent.capitalize()
        key = da + '-' + slot_type.capitalize()
        for pair in example.get(da, ()):
            # values of a repeated slot are consumed in order: the first hit may be any
            # occurrence, every later hit must carry the next index
            if (pair[0].lower() == slot_type) and ((key not in counter) or (counter[key] == int(pair[1]) - 1)):
                sen = sen.replace(word, pair[2], 1)
                counter[key] = int(pair[1])
                break
        else:
            # no value available for this placeholder: drop it from the sentence
            sen = sen.replace(word, '', 1)
    return sen


def interact(config, args):
    '''
    Generate sentences for one hard-coded example and print them.

    The saved model at args.model_path is loaded, the example dialogue acts are turned
    into a feature vector, and for each generated sentence both the delexicalised text
    and the text with slot values filled back in are printed.

    Args:
        config: ConfigParser; MODEL/hidden_size and MODEL/dropout are read here and the
                whole object is handed to SimpleDatasetWoz
        args:   parsed command-line arguments (n_layer, lr, beam_size, model_path and
                mode are read here)
    '''
    dataset = SimpleDatasetWoz(config)

    # get model hyper-parameters
    n_layer = args.n_layer
    hidden_size = config.getint('MODEL', 'hidden_size')
    dropout = config.getfloat('MODEL', 'dropout')
    lr = args.lr
    beam_size = args.beam_size

    # get feat size
    d_size = dataset.do_size + dataset.da_size + dataset.sv_size # len of 1-hot feat
    vocab_size = len(dataset.word2index)

    model = LMDeep('sclstm', vocab_size, vocab_size, hidden_size, d_size, n_layer=n_layer, dropout=dropout, lr=lr)
    model_path = args.model_path
    print(model_path)
    assert os.path.isfile(model_path), 'model file not found: {}'.format(model_path)
    # map to CPU first so a checkpoint saved on a GPU also loads on a CPU-only host
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()

    # Move models to GPU
    if USE_CUDA:
        model.cuda()

    data_type = args.mode
    print('INTERACT ON: {}'.format(data_type))

    example = {"Attraction-Inform": [["Choice","many"],["Area","centre of town"]],
               "Attraction-Select": [["Type","church"],["Type"," swimming"],["Type"," park"]]}
    # each pair becomes [slot, occurrence index, value]
    _number_repeated_slots(example)

    do_idx, da_idx, sv_idx, featStr = dataset.getFeatIdx(example)
    do_cond = [1 if i in do_idx else 0 for i in range(dataset.do_size)] # domain condition
    da_cond = [1 if i in da_idx else 0 for i in range(dataset.da_size)] # dial act condition
    sv_cond = [1 if i in sv_idx else 0 for i in range(dataset.sv_size)] # slot/value condition
    feats = [do_cond + da_cond + sv_cond]

    feats_var = torch.FloatTensor(feats)
    if USE_CUDA:
        feats_var = feats_var.cuda()

    # generation only, so no autograd graph is needed
    with torch.no_grad():
        decoded_words = model.generate(dataset, feats_var, beam_size)
    delex = decoded_words[0] # (beam_size)

    recover = [_lexicalize(sen, example) for sen in delex]

    print('meta', example)
    # bounded by what was actually generated, so fewer results than beam_size cannot overrun
    for i in range(min(beam_size, len(delex))):
        print(i, delex[i])
        print(i, recover[i])


def str2bool(v):
    '''
    Convert a command-line string to a bool, for use as an argparse `type=` callable.

    Args:
        v: one of 'yes', 'true', 't', 'y', '1' / 'no', 'false', 'f', 'n', '0'
           (case-insensitive); a value that is already a bool is returned unchanged
    Returns:
        the corresponding bool
    Raises:
        argparse.ArgumentTypeError: when v is none of the accepted spellings
    '''
    if isinstance(v, bool):
        # already converted, e.g. when called programmatically with a default value
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def parse():
	'''
	Parse the command line and load the matching configuration file.

	The configuration is 'config/config_usr.cfg' when --user is true and
	'config/config.cfg' otherwise. It is looked up relative to the current working
	directory first and, if nothing could be read there, relative to the directory of
	this script. DATA/dir is then set to the directory of this script.

	Returns:
		(args, config): the argparse namespace and the ConfigParser
	Raises:
		configparser.NoSectionError: when no configuration with a DATA section was read
	'''
	parser = argparse.ArgumentParser(description='Train dialogue generator')
	parser.add_argument('--mode', type=str, default='interact', help='train, adapt, test or interact')
	parser.add_argument('--model_path', type=str, default='sclstm.pt', help='saved model path')
	parser.add_argument('--n_layer', type=int, default=1, help='# of layers in LSTM')
	parser.add_argument('--percent', type=float, default=1, help='percentage of training data')
	parser.add_argument('--beam_search', type=str2bool, default=False, help='beam_search')
	parser.add_argument('--attn', type=str2bool, default=True, help='whether to use attention or not')
	parser.add_argument('--beam_size', type=int, default=10, help='number of generated sentences')
	parser.add_argument('--bs', type=int, default=256, help='batch size')
	parser.add_argument('--lr', type=float, default=0.0025, help='learning rate')
	parser.add_argument('--user', type=str2bool, default=False, help='use user data')
	args = parser.parse_args()

	script_dir = os.path.dirname(os.path.abspath(__file__))
	cfg_path = 'config/config_usr.cfg' if args.user else 'config/config.cfg'

	config = configparser.ConfigParser()
	# ConfigParser.read() skips missing files silently and returns the files it did read,
	# so an empty result means the cwd-relative lookup found nothing
	if not config.read(cfg_path):
		config.read(os.path.join(script_dir, cfg_path))
	config.set('DATA','dir', script_dir)

	return args, config


if __name__ == '__main__':
	# Seeding for reproducibility is currently switched off. The disabled calls were:
	#	random.seed(1235)
	#	torch.manual_seed(1235)
	#	torch.cuda.manual_seed_all(1235)

	args, config = parse()

	# 'adapt' shares the training entry point with 'train'
	handlers = {'train': train, 'adapt': train, 'test': test, 'interact': interact}
	handler = handlers.get(args.mode)
	if handler is None:
		print('mode cannot be recognized')
		sys.exit(1)
	handler(config, args)
